{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "17cd36c0-d807-48cb-a139-24b465b656e9",
   "metadata": {},
   "source": [
    "# Scalable nearest genome search with Bacformer\n",
    "\n",
    "This tutorial shows how to use Bacformer embeddings and the associated corpus of over 1.3M genomes to search for similar genomes.\n",
    "\n",
    "We provide a precomputed dataset of Bacformer genome embeddings, with associated metadata, for genomes drawn from a variety of sources.\n",
    "\n",
    "Given a query genome, you can embed it with Bacformer and run a fast, scalable search for the $k$ closest genomes (by L2 distance) among the more than 1.3M genomes in the dataset. The nearest genomes can then be used to annotate the query genome and to investigate similar ones.\n",
    "\n",
    "Before you start, make sure you have `bacformer` installed, together with the [datasets>=2.21](https://pypi.org/project/datasets/) and [faiss-cpu](https://pypi.org/project/faiss-cpu/) packages. The latter two can be installed with:\n",
    "```bash\n",
    "pip install -U \"datasets>=2.21\" faiss-cpu\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "be16bd36-9daa-4a18-bda5-ec15c81a30aa",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "\n",
    "from datasets import load_dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "615b4ead-c3c1-4778-8358-213caf67ad60",
   "metadata": {},
   "source": [
    "## Step 1: Download the dataset with Bacformer genome embeddings\n",
    "\n",
    "Download the precomputed dataset with Bacformer genome embeddings and the associated metadata."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "7c2ce0e3-e4ff-4f19-8688-2d5fbd5e4ecd",
   "metadata": {},
   "outputs": [],
   "source": [
    "ds = load_dataset(\"macwiatrak/bacformer-genome-embeddings-corpus\", split=\"test\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "42d2acf2-8404-411c-acd0-8350c372903f",
   "metadata": {},
   "source": [
    "## Step 2: Compute the FAISS index for scalable, fast vector search\n",
    "\n",
    "Use the [faiss](https://github.com/facebookresearch/faiss) package to compute an index which allows for efficient similarity search."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "3dfd7b67-8cf7-4790-b810-08c36035b3e3",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1364/1364 [00:07<00:00, 187.80it/s]\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "Dataset({\n",
       "    features: ['genome_name', 'source', 'n_seqs', 'n_contigs', 'genome_length', 'n50', 'completeness', 'contamination', 'derived_from_sample', 'env', 'phylum', 'class', 'order', 'family', 'genus', 'species', 'env_clean', 'taxid', 'genome_embedding', '__index_level_0__'],\n",
       "    num_rows: 1363276\n",
       "})"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ds.add_faiss_index(column=\"genome_embedding\")"
   ]
  },
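  {
   "cell_type": "markdown",
   "id": "b7e2a9f4-5c31-4d8e-9a60-l2-search-sketch",
   "metadata": {},
   "source": [
    "To illustrate what a nearest-neighbour search over embeddings computes, here is a minimal, dependency-free sketch of an exhaustive L2 search. It is not how FAISS is implemented, and the helper names (`squared_l2`, `brute_force_nearest`) and the toy vectors are made up for illustration. It ranks by squared Euclidean distance, which, to our knowledge, is also what a flat FAISS L2 index reports; if that holds for your index, take the square root to obtain the Euclidean distance.\n",
    "\n",
    "```python\n",
    "import heapq\n",
    "\n",
    "\n",
    "def squared_l2(a, b):\n",
    "    # Squared Euclidean distance between two equal-length vectors.\n",
    "    return sum((x - y) ** 2 for x, y in zip(a, b))\n",
    "\n",
    "\n",
    "def brute_force_nearest(corpus, query, k):\n",
    "    # Score every vector in the corpus against the query and keep the\n",
    "    # k smallest (distance, row index) pairs, nearest first.\n",
    "    scored = ((squared_l2(vec, query), idx) for idx, vec in enumerate(corpus))\n",
    "    return heapq.nsmallest(k, scored)\n",
    "\n",
    "\n",
    "# Toy 2-D vectors standing in for genome embeddings (illustration only).\n",
    "toy_corpus = [[0.0, 0.0], [1.0, 0.0], [0.0, 2.0], [3.0, 3.0]]\n",
    "neighbours = brute_force_nearest(toy_corpus, query=[0.0, 0.0], k=3)\n",
    "print(neighbours)\n",
    "```\n",
    "\n",
    "Scanning every vector in plain Python like this costs time proportional to the corpus size for each query, which is why the tutorial relies on the FAISS index for a corpus of over 1.3M genomes."
   ]
  },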
  {
   "cell_type": "markdown",
   "id": "5584a26c-239b-4a91-94f5-474696ae736e",
   "metadata": {},
   "source": [
    "## Step 3: Get the $k$ most similar genomes\n",
    "\n",
    "Fetch a query vector from the existing dataset. Using the query vector and the FAISS index computed in step 2. Search for similar genomes in our dataset and view the metadata."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "94e70551-7d2e-46aa-a752-31d7dea5fce1",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>genome_name</th>\n",
       "      <th>source</th>\n",
       "      <th>n_seqs</th>\n",
       "      <th>n_contigs</th>\n",
       "      <th>genome_length</th>\n",
       "      <th>n50</th>\n",
       "      <th>completeness</th>\n",
       "      <th>contamination</th>\n",
       "      <th>derived_from_sample</th>\n",
       "      <th>env</th>\n",
       "      <th>phylum</th>\n",
       "      <th>class</th>\n",
       "      <th>order</th>\n",
       "      <th>family</th>\n",
       "      <th>genus</th>\n",
       "      <th>species</th>\n",
       "      <th>env_clean</th>\n",
       "      <th>taxid</th>\n",
       "      <th>__index_level_0__</th>\n",
       "      <th>l2_distance</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>spire_mag_00829374</td>\n",
       "      <td>SPIRE</td>\n",
       "      <td>1965.0</td>\n",
       "      <td>25.0</td>\n",
       "      <td>2353457.0</td>\n",
       "      <td>124487.0</td>\n",
       "      <td>99.83</td>\n",
       "      <td>5.29</td>\n",
       "      <td>SAMN07491241</td>\n",
       "      <td>host-associated:animal host:mammalian host:hum...</td>\n",
       "      <td>Actinobacteriota</td>\n",
       "      <td>Actinomycetia</td>\n",
       "      <td>Actinomycetales</td>\n",
       "      <td>Bifidobacteriaceae</td>\n",
       "      <td>Bifidobacterium</td>\n",
       "      <td>Bifidobacterium longum</td>\n",
       "      <td>Human host</td>\n",
       "      <td>None</td>\n",
       "      <td>736918</td>\n",
       "      <td>0.054596</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>spire_mag_01587903</td>\n",
       "      <td>SPIRE</td>\n",
       "      <td>1803.0</td>\n",
       "      <td>27.0</td>\n",
       "      <td>2135822.0</td>\n",
       "      <td>105354.0</td>\n",
       "      <td>91.63</td>\n",
       "      <td>0.03</td>\n",
       "      <td>SAMN08993539</td>\n",
       "      <td>host-associated:animal host:mammalian host:hum...</td>\n",
       "      <td>Actinobacteriota</td>\n",
       "      <td>Actinomycetia</td>\n",
       "      <td>Actinomycetales</td>\n",
       "      <td>Bifidobacteriaceae</td>\n",
       "      <td>Bifidobacterium</td>\n",
       "      <td>Bifidobacterium longum</td>\n",
       "      <td>Human host</td>\n",
       "      <td>None</td>\n",
       "      <td>751325</td>\n",
       "      <td>0.054921</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>spire_mag_02980720</td>\n",
       "      <td>SPIRE</td>\n",
       "      <td>1832.0</td>\n",
       "      <td>27.0</td>\n",
       "      <td>2211833.0</td>\n",
       "      <td>133051.0</td>\n",
       "      <td>99.95</td>\n",
       "      <td>0.02</td>\n",
       "      <td>SAMN17719258</td>\n",
       "      <td>host-associated:animal host:mammalian host:hum...</td>\n",
       "      <td>Actinobacteriota</td>\n",
       "      <td>Actinomycetia</td>\n",
       "      <td>Actinomycetales</td>\n",
       "      <td>Bifidobacteriaceae</td>\n",
       "      <td>Bifidobacterium</td>\n",
       "      <td>Bifidobacterium longum</td>\n",
       "      <td>Human host</td>\n",
       "      <td>None</td>\n",
       "      <td>385705</td>\n",
       "      <td>0.061194</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>spire_mag_02997990</td>\n",
       "      <td>SPIRE</td>\n",
       "      <td>1822.0</td>\n",
       "      <td>27.0</td>\n",
       "      <td>2203378.0</td>\n",
       "      <td>104325.0</td>\n",
       "      <td>99.92</td>\n",
       "      <td>0.03</td>\n",
       "      <td>SAMEA14101315</td>\n",
       "      <td>host-associated:animal host:mammalian host:hum...</td>\n",
       "      <td>Actinobacteriota</td>\n",
       "      <td>Actinomycetia</td>\n",
       "      <td>Actinomycetales</td>\n",
       "      <td>Bifidobacteriaceae</td>\n",
       "      <td>Bifidobacterium</td>\n",
       "      <td>Bifidobacterium longum</td>\n",
       "      <td>Human host</td>\n",
       "      <td>None</td>\n",
       "      <td>397971</td>\n",
       "      <td>0.064117</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>spire_mag_01286748</td>\n",
       "      <td>SPIRE</td>\n",
       "      <td>1856.0</td>\n",
       "      <td>31.0</td>\n",
       "      <td>2243106.0</td>\n",
       "      <td>111647.0</td>\n",
       "      <td>99.95</td>\n",
       "      <td>0.02</td>\n",
       "      <td>SAMN09259925</td>\n",
       "      <td>host-associated:animal host:mammalian host:hum...</td>\n",
       "      <td>Actinobacteriota</td>\n",
       "      <td>Actinomycetia</td>\n",
       "      <td>Actinomycetales</td>\n",
       "      <td>Bifidobacteriaceae</td>\n",
       "      <td>Bifidobacterium</td>\n",
       "      <td>Bifidobacterium longum</td>\n",
       "      <td>Human host</td>\n",
       "      <td>None</td>\n",
       "      <td>637493</td>\n",
       "      <td>0.065715</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "          genome_name source  n_seqs  n_contigs  genome_length       n50  \\\n",
       "1  spire_mag_00829374  SPIRE  1965.0       25.0      2353457.0  124487.0   \n",
       "2  spire_mag_01587903  SPIRE  1803.0       27.0      2135822.0  105354.0   \n",
       "3  spire_mag_02980720  SPIRE  1832.0       27.0      2211833.0  133051.0   \n",
       "4  spire_mag_02997990  SPIRE  1822.0       27.0      2203378.0  104325.0   \n",
       "5  spire_mag_01286748  SPIRE  1856.0       31.0      2243106.0  111647.0   \n",
       "\n",
       "   completeness  contamination derived_from_sample  \\\n",
       "1         99.83           5.29        SAMN07491241   \n",
       "2         91.63           0.03        SAMN08993539   \n",
       "3         99.95           0.02        SAMN17719258   \n",
       "4         99.92           0.03       SAMEA14101315   \n",
       "5         99.95           0.02        SAMN09259925   \n",
       "\n",
       "                                                 env            phylum  \\\n",
       "1  host-associated:animal host:mammalian host:hum...  Actinobacteriota   \n",
       "2  host-associated:animal host:mammalian host:hum...  Actinobacteriota   \n",
       "3  host-associated:animal host:mammalian host:hum...  Actinobacteriota   \n",
       "4  host-associated:animal host:mammalian host:hum...  Actinobacteriota   \n",
       "5  host-associated:animal host:mammalian host:hum...  Actinobacteriota   \n",
       "\n",
       "           class            order              family            genus  \\\n",
       "1  Actinomycetia  Actinomycetales  Bifidobacteriaceae  Bifidobacterium   \n",
       "2  Actinomycetia  Actinomycetales  Bifidobacteriaceae  Bifidobacterium   \n",
       "3  Actinomycetia  Actinomycetales  Bifidobacteriaceae  Bifidobacterium   \n",
       "4  Actinomycetia  Actinomycetales  Bifidobacteriaceae  Bifidobacterium   \n",
       "5  Actinomycetia  Actinomycetales  Bifidobacteriaceae  Bifidobacterium   \n",
       "\n",
       "                  species   env_clean taxid  __index_level_0__  l2_distance  \n",
       "1  Bifidobacterium longum  Human host  None             736918     0.054596  \n",
       "2  Bifidobacterium longum  Human host  None             751325     0.054921  \n",
       "3  Bifidobacterium longum  Human host  None             385705     0.061194  \n",
       "4  Bifidobacterium longum  Human host  None             397971     0.064117  \n",
       "5  Bifidobacterium longum  Human host  None             637493     0.065715  "
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# use the first genome in the dataset as the query\n",
    "query_vector = np.array(ds[0]['genome_embedding'], dtype=np.float32)\n",
    "# set k (i.e. number of most similar genomes to retrieve)\n",
    "k = 10\n",
    "# request k + 1 results: the query genome is part of the dataset and is\n",
    "# expected to be returned as its own nearest neighbour\n",
    "dists, examples = ds.get_nearest_examples(\n",
    "    \"genome_embedding\",\n",
    "    query_vector,\n",
    "    k=k + 1,\n",
    ")\n",
    "# convert the examples to a DataFrame and drop the first row (the query genome itself)\n",
    "examples = pd.DataFrame(examples)[1:].drop(columns=['genome_embedding'])\n",
    "# add the distances returned by the FAISS index\n",
    "examples['l2_distance'] = dists[1:]\n",
    "examples.head()"
   ]
  },
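  {
   "cell_type": "markdown",
   "id": "c4d81e07-2f6a-4b93-8e15-neighbour-vote-sketch",
   "metadata": {},
   "source": [
    "One simple way to use the retrieved neighbours to annotate a query genome is a majority vote over a metadata field such as `genus` or `species`. The sketch below assumes this voting scheme purely for illustration; it is not part of Bacformer, and `vote_label` and the toy records are hypothetical.\n",
    "\n",
    "```python\n",
    "from collections import Counter\n",
    "\n",
    "\n",
    "def vote_label(neighbours, field):\n",
    "    # Collect the non-missing labels for `field`, then return the most\n",
    "    # common label and the fraction of labelled neighbours that share it.\n",
    "    labels = [row[field] for row in neighbours if row.get(field) is not None]\n",
    "    if not labels:\n",
    "        return None, 0.0\n",
    "    label, count = Counter(labels).most_common(1)[0]\n",
    "    return label, count / len(labels)\n",
    "\n",
    "\n",
    "# Hypothetical neighbour metadata, made up for illustration only.\n",
    "toy_neighbours = [\n",
    "    {'genus': 'Bifidobacterium', 'species': 'Bifidobacterium longum'},\n",
    "    {'genus': 'Bifidobacterium', 'species': 'Bifidobacterium longum'},\n",
    "    {'genus': 'Bifidobacterium', 'species': 'Bifidobacterium breve'},\n",
    "    {'genus': 'Bifidobacterium', 'species': None},\n",
    "]\n",
    "genus, genus_support = vote_label(toy_neighbours, 'genus')\n",
    "species, species_support = vote_label(toy_neighbours, 'species')\n",
    "print(genus, genus_support)\n",
    "print(species, species_support)\n",
    "```\n",
    "\n",
    "With the DataFrame from Step 3, `examples.to_dict('records')` yields records in this shape. A low support fraction suggests the neighbours disagree, in which case the voted label should be treated with caution."
   ]
  },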
  {
   "cell_type": "markdown",
   "id": "d03541fd-ca82-49bf-b261-d1a9f41596d2",
   "metadata": {},
   "source": [
    "----------------------\n",
    "\n",
    "#### Voilà, you made it 👏!\n",
    "\n",
    "If you run into any issues or have questions, please open an issue on GitHub: https://github.com/macwiatrak/Bacformer/issues."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.17"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
